
#ifndef STATSxx_MACHINE_LEARNING_RESTRICTED_BOLTZMANN_MACHINE_RBM_HPP
#define STATSxx_MACHINE_LEARNING_RESTRICTED_BOLTZMANN_MACHINE_RBM_HPP


// STL
#include <ostream> // std::ostream
#include <tuple>   // std::tuple<>
#include <utility> // std::pair<>

// jScience
#include "jScience/linalg/Matrix.hpp" // Matrix<>
#include "jScience/linalg/Vector.hpp" // Vector<>
#include "jScience/stl/ostream.hpp"   // NullStream

// stats++
//#include "statsxx/machine_learning/neural_network/deep_belief_network/DBN.hpp" // neural_network::DBN

// Forward-declare neural_network::DBN (inside its own namespace) so that RBM below can name it as a friend class.
namespace neural_network { class DBN; } // declaration only: lets RBM grant DBN friendship

namespace restricted_Boltzmann_machine
{
    //=========================================================
    // RESTRICTED BOLTZMANN MACHINE
    //=========================================================
    // Holds `nvis` visible units and `nhid` hidden units, parameterized by a
    // weight matrix W, visible-unit biases b, and hidden-unit biases c.
    //
    // The member functions are defined in the src/RBM*.cpp files that are
    // #included at the bottom of this header.
    //=========================================================
    class RBM
    {

        // the DBN class needs access to the RBM's weights and biases
        friend class neural_network::DBN;

    public:

        // Construct an RBM with `nv` visible units and `nh` hidden units.
        RBM(const int nv,
            const int nh);
        ~RBM();

        // Declaring the destructor above suppresses the compiler-generated
        // move constructor / move assignment, so rvalue RBMs (e.g. during
        // std::vector reallocation or return by value) would otherwise be
        // copied member-by-member. Default all four copy/move operations
        // explicitly (rule of five) to restore move semantics.
        RBM(const RBM&)            = default;
        RBM(RBM&&)                 = default;
        RBM& operator=(const RBM&) = default;
        RBM& operator=(RBM&&)      = default;

        // Initialize the weights from the data X over `ntrials` trials.
        // NOTE(review): the meaning of the returned double (e.g. the error of
        // the best trial) is defined in RBM_initialize_weights.cpp -- confirm.
        double initialize_weights(
                                  const int ntrials,        // number of trials to conduct
                                  const Matrix<double> &X,  // data
                                  const bool x_binomial     // whether the data is binomial
                                  );



        // Train the RBM on the data X by stochastic gradient, using up to
        // `max_epochs` epochs of `nbatches` batches each. Progress is written
        // to `os` (defaults to NullStream).
        void stochastic_gradient(
                                 const int max_epochs,             // maximum number of epochs
                                 const int nbatches,               // number of batches per epoch
                                 // -----
                                 double lr,                        // learning rate
                                 double momentum,                  // momentum
                                 // -----
                                 const double weight_penalty,      // weight penalty
                                 // -----
                                 const double sparsity_target,     // sparsity target
                                 const double sparsity_penalty,    // sparsity penalty
                                 const double sparsity_multiplier, // multiplier for q (hidden activation) EMA
                                 const double q_min,               // minimum (desired) q
                                 const double q_max,               // maximum (desired) q
                                 const double q_penalty,           // q penalty
                                 // -----
                                 const int K_begin,                // number of starting Monte Carlo (MC) iterations
                                 const int K_end,                  // number of ending Monte Carlo iterations
                                 const double K_rate,              // rate of updates to number of MC iterations
                                 // -----
                                 const bool mean_field,            // mean-field approximation
                                 // -----
                                 const double convg_criterion,     // convergence criterion
                                 const int max_no_improvement,     // max number of epochs to go without an improvement in convergence
                                 // -----
                                 const Matrix<double> &X,          // data
                                 const bool x_binomial,            // whether the data is binomial
                                 // -----
                                 std::ostream* os = &NullStream    // (optional) output stream
                                 );

        // Reconstruct a single data vector x.
        // Returns (reconstructed visible units, hidden units, reconstruction error).
        std::tuple<
                   Vector<double>, // visible units (reconstructed)
                   Vector<double>, // hidden units
                   double          // reconstruction error
                   > reconstruct(
                                 const Vector<double> &x,   // data
                                 const bool x_binomial      // whether the data is binomial
                                 ) const;

        // Run K Markov (Gibbs) iterations starting from the visible vector v
        // and return the resulting vector.
        // NOTE(review): presumably the returned vector is in visible space --
        // confirm in RBM_Gibbs.cpp.
        Vector<double> Gibbs(
                             const int K,                    // number of Markov iterations
                             Vector<double> v,               // initial vector v
                             const bool mean_field = false   // mean-field approximation
                             ) const;

        // Dimensions of the machine.
        int get_nvis() const;
        int get_nhid() const;

        // Return copies of the parameters: weights W, visible biases b,
        // and hidden biases c.
        Matrix<double> get_W() const;
        Vector<double> get_b() const;
        Vector<double> get_c() const;

    private:

        int nvis; // number of visible units
        int nhid; // number of hidden units

        Matrix<double> W; // weight matrix
        Vector<double> b; // visible neuron biases
        Vector<double> c; // hidden neuron biases


        // Estimate the log-likelihood gradient for a single data vector x
        // using K Monte Carlo iterations.
        // NOTE(review): the tuple element order is presumably
        // (dW, db, dc, <hidden-unit statistic>, <error term>) -- confirm
        // against RBM_grad_log_likelihood.cpp before relying on it.
        std::tuple<
                   Matrix<double>,
                   Vector<double>,
                   Vector<double>,
                   Vector<double>,
                   double
                   > grad_log_likelihood(
                                         const int K,               // number of Monte Carlo iterations
                                         const bool mean_field,     // mean-field approximation
                                         const Vector<double> &x,   // data
                                         const bool x_binomial      // whether the data is binomial
                                         );

        // Single sampling step starting from the visible data v.
        // NOTE(review): the name suggests a visible -> hidden -> visible
        // step -- confirm in RBM_Gibbs_vhv.cpp.
        Vector<double> Gibbs_vhv(
                                 Vector<double> v,              // initial visible data
                                 const bool mean_field = false  // mean-field approximation
                                 ) const;

    };
}

#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_initialize_weights.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_stochastic_gradient.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_reconstruct.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_get_nvis.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_get_nhid.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_get_W.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_get_b.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_get_c.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_grad_log_likelihood.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_Gibbs.cpp"
#include "statsxx/machine_learning/restricted_Boltzmann_machine/src/RBM_Gibbs_vhv.cpp"


#endif
